from copy import deepcopy
import tensorflow as tf

def create_SAC_algorithm(variant, *args, **kwargs):
    """Construct a ``SAC`` algorithm.

    ``variant`` is accepted so every factory in ``ALGORITHM_CLASSES`` shares
    one call signature, but it is not used here; all other positional and
    keyword arguments are forwarded unchanged to the ``SAC`` constructor.
    """
    # Imported lazily: ``.sac`` is only loaded when this factory is called.
    from .sac import SAC as algorithm_cls
    return algorithm_cls(*args, **kwargs)

def create_SACDM_algorithm(variant, *args, **kwargs):
    """Construct a ``SACDM`` algorithm.

    ``variant`` is accepted so every factory in ``ALGORITHM_CLASSES`` shares
    one call signature, but it is not used here; all other positional and
    keyword arguments are forwarded unchanged to the ``SACDM`` constructor.
    """
    # Imported lazily: ``.sac_dm`` is only loaded when this factory is called.
    from .sac_dm import SACDM as algorithm_cls
    return algorithm_cls(*args, **kwargs)

def create_SQLDM_algorithm(variant, *args, **kwargs):
    """Construct a ``SQLDM`` algorithm.

    ``variant`` is accepted so every factory in ``ALGORITHM_CLASSES`` shares
    one call signature, but it is not used here; all other positional and
    keyword arguments are forwarded unchanged to the ``SQLDM`` constructor.
    """
    # Imported lazily: ``.sql_dm`` is only loaded when this factory is called.
    from .sql_dm import SQLDM as algorithm_cls
    return algorithm_cls(*args, **kwargs)

def create_SQL_algorithm(variant, *args, **kwargs):
    """Construct a ``SQL`` algorithm.

    ``variant`` is accepted so every factory in ``ALGORITHM_CLASSES`` shares
    one call signature, but it is not used here; all other positional and
    keyword arguments are forwarded unchanged to the ``SQL`` constructor.
    """
    # Imported lazily: ``.sql`` is only loaded when this factory is called.
    from .sql import SQL as algorithm_cls
    return algorithm_cls(*args, **kwargs)

# Registry of algorithm factories, keyed by the string found in
# ``variant['algorithm_params']['type']``. Each factory is called as
# ``factory(variant, *args, **kwargs)`` by ``get_algorithm_from_variant``.
ALGORITHM_CLASSES = dict(
    SAC=create_SAC_algorithm,
    SQL=create_SQL_algorithm,
    SACDM=create_SACDM_algorithm,
    SQLDM=create_SQLDM_algorithm,
)


def get_algorithm_from_variant(variant,
                               *args,
                               **kwargs):
    """Build the algorithm selected by ``variant['algorithm_params']``.

    Args:
      variant: experiment config. ``variant['algorithm_params']['type']``
        selects the factory in ``ALGORITHM_CLASSES`` and
        ``variant['algorithm_params']['kwargs']`` holds keyword arguments for
        it. The whole ``variant`` is also passed to the factory as its first
        argument.
      *args: positional arguments forwarded to the factory.
      **kwargs: extra keyword arguments forwarded after the config kwargs. A
        name present in both raises ``TypeError`` (duplicate keyword argument).
    Returns:
      Whatever the selected factory returns.
    Raises:
      KeyError: if a required config key is missing, or if ``type`` is not a
        registered algorithm name (the message lists the valid names).
    """
    algorithm_params = variant['algorithm_params']
    algorithm_type = algorithm_params['type']
    # Deep copy so nested config values handed to the algorithm are not
    # shared with (and cannot be mutated through) the caller's variant.
    algorithm_kwargs = deepcopy(algorithm_params['kwargs'])

    try:
        factory = ALGORITHM_CLASSES[algorithm_type]
    except KeyError:
        # A bare KeyError('FOO') gives no hint of what is valid; keep the
        # same exception type but name the registered algorithms.
        raise KeyError(
            'Unknown algorithm type {!r}; expected one of: {}'.format(
                algorithm_type, ', '.join(sorted(ALGORITHM_CLASSES)))
        ) from None

    return factory(variant, *args, **algorithm_kwargs, **kwargs)


def sample_gumbel(shape, eps=1e-20):
    """Draw samples from the standard Gumbel(0, 1) distribution.

    Uses inverse-transform sampling: for U ~ Uniform[0, 1),
    -log(-log(U)) is Gumbel(0, 1) distributed.

    Args:
      shape: shape of the returned tensor (e.g. the 1-D int tensor produced
        by ``tf.shape``).
      eps: small constant added inside both logarithms so neither is
        evaluated at exactly zero (U may be 0).
    Returns:
      Float tensor of the requested ``shape``.
    """
    uniform = tf.random_uniform(shape, minval=0, maxval=1)
    neg_log_uniform = -tf.log(uniform + eps)
    return -tf.log(neg_log_uniform + eps)


def gumbel_softmax_sample(logits, temperature):
    """Draw a soft (relaxed) sample from the Gumbel-Softmax distribution.

    Gumbel noise of the same shape as ``logits`` is added, the result is
    divided by ``temperature`` and a softmax is applied (``tf.nn.softmax``
    normalizes over the last axis by default).

    Args:
      logits: unnormalized log-probabilities.
      temperature: scalar divisor applied before the softmax; smaller values
        give outputs closer to one-hot.
    Returns:
      Tensor shaped like ``logits`` whose last-axis slices sum to 1.
    """
    gumbel_noise = sample_gumbel(tf.shape(logits))
    perturbed_logits = logits + gumbel_noise
    return tf.nn.softmax(perturbed_logits / temperature)


def gumbel_softmax(logits, temperature, hard=True):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
      logits: [batch_size, n_class] unnormalized log-probs. Classes are read
        from the last axis, so higher-rank inputs are handled the same way.
      temperature: non-negative scalar
      hard: if True, take argmax, but differentiate w.r.t. soft sample y
    Returns:
      Tensor shaped like ``logits``: a sample from the Gumbel-Softmax
      distribution. If hard=True, the returned sample is exactly one-hot along
      the last axis (straight-through estimator); otherwise it is a
      probability distribution that sums to 1 across classes.
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        n_classes = tf.shape(logits)[-1]
        # argmax + one_hot yields exactly one active class even when two
        # entries of y tie for the maximum; an equality test against
        # reduce_max would mark every tied entry (multi-hot). Axis -1 matches
        # the axis tf.nn.softmax normalizes over in gumbel_softmax_sample.
        y_hard = tf.one_hot(tf.argmax(y, axis=-1), n_classes, dtype=y.dtype)
        # Straight-through: the forward value is y_hard, while gradients
        # flow only through the soft sample y.
        y = tf.stop_gradient(y_hard - y) + y
    return y

def f_gan_activation(divergence, v):
    """Apply the f-GAN output activation g_f for ``divergence`` to ``v``.

    The formulas match the output activations g_f in Table 2 of Nowozin et
    al., "f-GAN" (2016); ``f_gan_conjuate`` in this module implements the
    matching conjugates f*.

    Args:
      divergence: one of "KLD" (Kullback-Leibler), "RKL" (reverse KL),
        "CHI" (Pearson chi-squared), "SQH" (squared Hellinger),
        "JSD" (Jensen-Shannon) or "GAN".
      v: tensor of unconstrained values.
    Returns:
      Tensor with the same shape as ``v``.
    Raises:
      ValueError: if ``divergence`` is not one of the names above.
    """
    if divergence == "KLD":
        return v
    elif divergence == "RKL":
        return -tf.exp(-v)
    elif divergence == "CHI":
        return v
    elif divergence == "SQH":
        return 1 - tf.exp(-v)
    elif divergence == "JSD":
        # log(2) - log(1 + exp(-v)); softplus(x) = log(1 + exp(x)) computed
        # without overflowing exp(-v) for large negative v.
        return tf.log(tf.constant(2.)) - tf.nn.softplus(-v)
    elif divergence == "GAN":
        # log sigmoid(v) = -log(1 + exp(-v)), computed stably.
        return -tf.nn.softplus(-v)
    raise ValueError(
        "Unknown divergence {!r}; expected one of "
        "KLD, RKL, CHI, SQH, JSD, GAN".format(divergence))

def f_gan_conjuate(divergence, t):
    """Evaluate the convex conjugate f*(t) of the chosen f-divergence.

    The formulas match the conjugate functions f* in Table 2 of Nowozin et
    al., "f-GAN" (2016). ``t`` must lie in the domain of f*, e.g. t < 0 for
    "RKL" and "GAN" and t < log 2 for "JSD" (the log argument must be
    positive); outside it the formulas give NaN/inf or meaningless values.

    The function name is misspelled; ``f_gan_conjugate`` is provided as an
    alias.

    Args:
      divergence: one of "KLD", "RKL", "CHI", "SQH", "JSD", "GAN".
      t: tensor of values in the domain of f*.
    Returns:
      Tensor with the same shape as ``t``.
    Raises:
      ValueError: if ``divergence`` is not one of the names above.
    """
    if divergence == "KLD":
        return tf.exp(t - 1)
    elif divergence == "RKL":
        return -1 - tf.log(-t)
    elif divergence == "CHI":
        return 0.25 * t**2 + t
    elif divergence == "SQH":
        return t / (tf.constant(1.) - t)
    elif divergence == "JSD":
        return -tf.log(2.0 - tf.exp(t))
    elif divergence == "GAN":
        return -tf.log(1.0 - tf.exp(t))
    raise ValueError(
        "Unknown divergence {!r}; expected one of "
        "KLD, RKL, CHI, SQH, JSD, GAN".format(divergence))


# Correctly spelled alias; the original name is kept for existing callers.
f_gan_conjugate = f_gan_conjuate